# -*- coding: UTF-8 -*-
"""
    @Author:YTQ
    @Time: 2022/11/28 20:31
    Description:
    
"""
import torch

from config import *
import os
from utils.utils import get_label


class Config(object):
    """配置参数"""

    def __init__(self, obj):
        self.model_name = 'bert_CNN'
        self.bert_path = MODULE_BERT_PATH
        self.train_path = TRAIN_DATA_PATH  # 训练集
        self.dev_path = DEV_DATA_PATH  # 验证集
        self.test_path = TEST_DATA_PATH  # 测试集
        self.class_list, self.class_dir = get_label()  # 类别名单
        self.save_path = MODULE_SAVE_PATH + os.sep  # 模型训练结果
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 设备
        self.isUseGpu = True if self.device == 'cuda' else False
        self.require_improvement = 1000  # 若超过1000batch效果还没提升，则提前结束训练
        self.num_classes = len(self.class_list)  # 类别数
        self.epochs_num = obj['epoch']  # epoch数
        self.batch_size = obj['batch_size']  # mini-batch大小
        self.pad_size = obj['pad_size']  # 每句话处理成的长度(短填长切)
        self.learning_rate = 5e-5  # 学习率
        self.hidden_size = obj['hidden_size']
        self.filter_sizes = obj['filter_sizes']  # (2, 3, 4)  # 卷积核尺寸
        self.num_filters = obj['num_filters']  # 256  # 卷积核数量(channels数)
        self.dropout = 0.1
